from torch.optim import Adam, SGD

def init_opitm(params, opt):
    """Build a torch optimizer selected by ``opt.optim``.

    Args:
        params: Parameters (or parameter-group dicts) handed unchanged to the
            optimizer constructor.
        opt: Options object. Always reads ``opt.optim`` (matched
            case-insensitively against ``'adam'`` / ``'sgd'``), ``opt.lr`` and
            ``opt.weight_decay``. Also reads ``opt.betas`` (converted to a
            tuple) for Adam, and ``opt.momentum`` for SGD.

    Returns:
        A configured ``Adam`` or ``SGD`` optimizer instance.

    Raises:
        AttributeError: if ``opt.optim`` does not name a supported optimizer.
    """
    # Name lookup and construction share a single table, so every accepted
    # name has a builder; a registered name can no longer fall through an
    # if/elif chain and implicitly return None. The lambdas defer reading the
    # optimizer-specific options until the name has been validated.
    builders = {
        'adam': lambda: Adam(
            params, lr=opt.lr, betas=tuple(opt.betas), weight_decay=opt.weight_decay
        ),
        'sgd': lambda: SGD(
            params, lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay
        ),
    }
    optim = opt.optim.lower()
    try:
        build = builders[optim]
    except KeyError:
        # AttributeError (not ValueError) is kept for existing callers that catch it.
        raise AttributeError('wrong name of optim: %s' % optim) from None
    return build()